import os

import numpy as np
import pandas as pd

from src.clustering.base_distances_bis import (
    norm_hamming,
    jaccard,
    rms_jaccard,
    geom_jaccard,
    rms_hamming,
    geom_hamming,
)
from src.imports.import_pabulib import get_files_names_from_folder

# Registry of the distance methods this script reports on, keyed by the name
# that appears in the precomputed result files
# (``<name>_distances_<instance>.csv``). Only the keys are used below; the
# insertion order fixes the order in which per-method results are printed.
base_distances = dict(
    jaccard=jaccard,
    geom_jaccard=geom_jaccard,
    norm_hamming=norm_hamming,
    rms_jaccard=rms_jaccard,
    geom_hamming=geom_hamming,
    rms_hamming=rms_hamming,
)


if __name__ == "__main__":


    targets = ['warszawa_2023_districts','warszawa_2024_districts',
               'krakow_2023_districts', 'krakow_2024_districts',
                'lodz_2023_districts', 'lodz_2024_districts']

    for target in targets:
        print(target)
        plot_all_intra = {method: [] for method in base_distances}
        plot_all_inter = {method: [] for method in base_distances}
        plot_all_diff = {method: [] for method in base_distances}
        plot_all_sign = {method: [] for method in base_distances}
        plot_all_closest = {method: [] for method in base_distances}

        correct = {method: 0 for method in base_distances}

        instances = get_files_names_from_folder(f'data/pabulib/{target}')

        for instance_id in instances:

            ALL_INTRA = {method: [] for method in base_distances}
            ALL_INTER = {method: [] for method in base_distances}
            ALL_DIFF = {method: [] for method in base_distances}
            ALL_SIGN = {method: [] for method in base_distances}
            ALL_CLOSEST = {method: [] for method in base_distances}
            # for t in tqdm.tqdm(range(num_tests)):
            for method in base_distances:

                try:
                    file_path = f"output/pabulib/cat/{target}/{method}_distances_{instance_id}.csv"
                    file_path = os.path.join(os.getcwd(), file_path)
                    distances = pd.read_csv(file_path, delimiter=';')
                except:
                    continue

                if len(distances) == 0:
                    continue


                avg_intra_party_distance = np.mean(distances['intra'])
                avg_inter_party_distance = np.mean(distances['inter'])
                avg_diff = np.mean(distances['inter'] - distances['intra'])
                # avg_closest_friend_distance = np.mean(distances['closest'])


                tmp = distances['inter']/distances['intra']
                avg_sing = sum(1 for x in tmp if x > 1.01) / len(tmp)


                ALL_INTRA[method].append(avg_intra_party_distance)
                ALL_INTER[method].append(avg_inter_party_distance)
                ALL_DIFF[method].append(avg_diff)
                ALL_SIGN[method].append(avg_sing)
                # ALL_CLOSEST[method].append(avg_closest_friend_distance)

                plot_all_intra[method].append(np.mean(ALL_INTRA[method]))
                plot_all_inter[method].append(np.mean(ALL_INTER[method]))
                plot_all_diff[method].append(np.mean(ALL_DIFF[method]))
                plot_all_sign[method].append(np.mean(ALL_SIGN[method]))

        for method in base_distances:
            print(method,
                  round(np.mean(plot_all_sign[method]),4))
                  # round(np.std(plot_all_sign[method]),4))


